# Return a pool of qualitative colors for plots.
# The pool concatenates the RColorBrewer qualitative palettes Paired, Set1,
# Set2, Set3, Pastel2, Pastel1 and Accent, removes duplicated hex codes, and
# then drops entries 17 and 23 of the de-duplicated pool (presumably colors
# that render poorly -- TODO confirm).
#
# colors_len  number of colors to return, taken from the front of the pool.
#             NULL (default) returns the whole pool; 0 returns character(0).
#             Requests larger than the pool are padded with NA (as before),
#             now with a warning.
# Returns a character vector of hex color codes.
get_colors <- function(colors_len = NULL){
  cols <- c(RColorBrewer::brewer.pal(12,"Paired"),
            RColorBrewer::brewer.pal(9,"Set1"),
            RColorBrewer::brewer.pal(8,"Set2"),
            RColorBrewer::brewer.pal(12,"Set3"),
            RColorBrewer::brewer.pal(8,"Pastel2"),
            RColorBrewer::brewer.pal(9,"Pastel1"),
            RColorBrewer::brewer.pal(8,"Accent"))
  pool <- unique(cols)[-c(17,23)]
  # An explicit NULL (e.g. forwarded from a caller's NULL default) used to
  # reach `1:NULL` and fail; treat it exactly like a missing argument.
  if (is.null(colors_len)) {
    return(pool)
  }
  if (colors_len > length(pool)) {
    warning("Requested ", colors_len, " colors but only ", length(pool),
            " are available; the extra entries will be NA.", call. = FALSE)
  }
  # seq_len() instead of 1:n so that colors_len = 0 yields no colors
  pool[seq_len(colors_len)]
}

# Custom palette of 34 hand-picked hex colors, stored in the global `col`.
col <- c(
  "#b5b6db", "#156e8a", "#d5dea1", "#a7559c", "#3e64a3", "#a4d7e3",
  "#6ca0d6", "#3b64ad", "#e4dcc0", "#f19570", "#ebd6e8", "#f093aa",
  "#e23626", "#0d783d", "#7fbe70", "#cedff0", "#4da1d1", "#1b4279",
  "#87ab3f", "#a1ca7b", "#91398f", "#a680ba", "#c29371", "#f3ddeb",
  "#d33a6c", "#b97782", "#194791", "#de7b91", "#e59cc4", "#badfe9",
  "#852e8a", "#efc0da", "#7d85ba", "#4ab6b0"
)

# Collect what is needed to draw an embedding with ggplot2.
# The result has these columns, in order:
#   - "cell": the cell barcodes, moved from the row names;
#   - the coordinates returned by Seurat::Embeddings() for `reduction`;
#   - the columns of project@meta.data.
# project    Seurat object
# reduction  name of the dimensional reduction to extract (default 'tsne')
# Rows are combined by position with cbind() -- assumes meta.data rows are in
# the same cell order as the embedding (TODO confirm for subset objects).
# Side effect: attaches dplyr via library().
get_reduction_data <- function(project = NULL,
                               reduction = 'tsne'){
  library(dplyr)
  coords <- data.frame(Seurat::Embeddings(project, reduction = reduction))
  with_meta <- cbind(coords, project@meta.data)
  tibble::rownames_to_column(with_meta, var = "cell")
}

# Draw the main tSNE or UMAP scatter plot, colored by one metadata column.
#
# reduction_data  data frame whose columns 2 and 3 hold the x/y embedding
#                 coordinates (e.g. tSNE_1/tSNE_2 or UMAP_1/UMAP_2); column 1
#                 is expected to be the cell barcode, as produced by
#                 get_reduction_data()
# cluster_type    name of the column used to color the points
# show_legend     whether to show the color legend
# show_labels     whether to write each `cluster_type` value at the median
#                 position of its cells
# point_size      size of the points
# base_size       base font size for theme_classic()
# key_size        size of the points in the legend keys
# label_size      text size of the on-plot labels
# adjust_axis     if TRUE, tidydr::theme_dr() keeps only short axis arrows in
#                 the lower-left corner
# arrow_type      arrow head type used when adjust_axis = TRUE: "closed" or
#                 "open"
#
# Blank ("") values in `cluster_type` are treated as missing and get no
# manual color; numeric values are converted to a factor (discrete colors).
# When labelling a column other than "Cluster" and the data also has a
# "Cluster" column, label positions are computed per (Cluster, label) pair and
# only the first pair of each label is kept, so every label is written once.
# Returns a ggplot object.
draw_main <- function(reduction_data = NULL,
                      cluster_type = "Cluster",
                      show_legend = TRUE,
                      show_labels = FALSE,
                      point_size = 0.5,
                      base_size = 14,
                      key_size = 5,
                      label_size = 5,
                      adjust_axis = FALSE,
                      arrow_type = "closed"){
  library(dplyr)
  if (!is.data.frame(reduction_data) || ncol(reduction_data) < 3 ||
      !cluster_type %in% colnames(reduction_data)) {
    stop("`reduction_data` must be a data frame with the embedding in ",
         "columns 2-3 and a column named '", cluster_type, "'.", call. = FALSE)
  }
  x_col <- colnames(reduction_data)[2]
  y_col <- colnames(reduction_data)[3]
  # Blank entries are undefined clusters: turn them into NA
  reduction_data[[cluster_type]][reduction_data[[cluster_type]] == ""] <- NA
  # Numeric ids would otherwise produce a continuous color scale
  if(is.numeric(reduction_data[[cluster_type]])){
    reduction_data[[cluster_type]] <- as.factor(reduction_data[[cluster_type]])
  }
  # Manual colors are assigned to the sorted non-missing values only
  present <- unique(reduction_data[[cluster_type]])
  breaks <- sort(present[!is.na(present)])
  # .data[[...]] replaces the deprecated aes_string(); axis and legend titles
  # are still taken from the column names
  p <- ggplot2::ggplot() +
    ggplot2::geom_point(data = reduction_data,
                        ggplot2::aes(x = .data[[x_col]],
                                     y = .data[[y_col]],
                                     color = .data[[cluster_type]]),
                        size = point_size, show.legend = show_legend) +
    ggplot2::scale_color_manual(breaks = breaks,
                                values = get_colors(colors_len = length(breaks)),
                                guide = ggplot2::guide_legend(override.aes = list(size = key_size))) +
    ggplot2::theme_classic(base_size = base_size)
  if(isTRUE(adjust_axis)){
    p <- p + tidydr::theme_dr(xlength = 0.3, ylength = 0.3,
                              arrow = grid::arrow(length = ggplot2::unit(0.15, "inches"),
                                                  type = arrow_type)) +
      ggplot2::theme(panel.border = ggplot2::element_blank(),
                     panel.grid.major = ggplot2::element_blank(),
                     panel.grid.minor = ggplot2::element_blank())
  }
  if(isTRUE(show_labels)){
    # Labelling another column: also group by Cluster (when present) so a
    # label spanning several clusters is placed at its first cluster's center
    if (cluster_type == "Cluster" || !"Cluster" %in% colnames(reduction_data)) {
      grouped <- dplyr::group_by(reduction_data, .data[[cluster_type]])
    } else {
      grouped <- dplyr::group_by(reduction_data, .data[["Cluster"]],
                                 .data[[cluster_type]])
    }
    label_pos <- grouped %>%
      dplyr::summarise(x_pos = stats::median(.data[[x_col]]),
                       y_pos = stats::median(.data[[y_col]]),
                       .groups = "drop") %>%
      dplyr::filter(.data[[cluster_type]] != "") %>%
      tidyr::drop_na()
    # Keep one position per label (a no-op when grouping by the label alone)
    label_pos <- label_pos[!duplicated(label_pos[[cluster_type]]), ]
    p <- p + ggplot2::geom_text(data = label_pos,
                                ggplot2::aes(x = .data[["x_pos"]],
                                             y = .data[["y_pos"]],
                                             label = .data[[cluster_type]]),
                                size = label_size)
  }
  p
}


# Prepare the data for, and draw, a "numbered circle" legend.
# Each cell type gets its own row(s) on the y axis. Its clusters are drawn as
# colored circles that carry the cluster id.
#
# data             data frame containing the cluster and cell type columns
# cluster          name of the cluster id column
# all_cluster_num  length of the color pool requested from get_colors(). It
#                  should cover every cluster level so the colors agree with
#                  the main plot. NULL (default) uses the number of cluster
#                  levels found in `data`.
# celltype         name of the cell type column
# col_num          maximum number of circles per row. A cell type with more
#                  clusters wraps onto extra rows; only its first row carries
#                  the cell type name.
#
# Colors: a cluster is drawn with get_colors()[i], where i is the cluster's
# position among the factor levels of the cluster column. The colors match
# draw_main() when these levels are the same clusters, in the same order, as
# the main plot's color breaks (e.g. pass a factor that keeps all levels).
#
# Returns a list with two elements:
#   legend_data  the layout table (Cluster, CellType, times, x, y)
#   legend_fig   the ggplot object
draw_legend <- function(data = NULL,
                        cluster = "Cluster",
                        all_cluster_num = NULL,
                        celltype = "CellType",
                        col_num = 10){
  library(dplyr)
  if (!is.data.frame(data) || !all(c(cluster, celltype) %in% colnames(data))) {
    stop("`data` must be a data frame containing the columns '", cluster,
         "' and '", celltype, "'.", call. = FALSE)
  }
  if (!is.numeric(col_num) || length(col_num) != 1 || is.na(col_num) || col_num < 1) {
    stop("`col_num` must be a single number >= 1.", call. = FALSE)
  }
  col_num <- as.integer(col_num)

  # all_of(): select the columns *named by* `cluster`/`celltype`, never a
  # column that happens to be literally called "cluster"/"celltype"
  legend_data <- data %>%
    dplyr::select(dplyr::all_of(c(cluster, celltype))) %>%
    unique()
  colnames(legend_data) <- c("Cluster","CellType")
  legend_data$Cluster <- as.factor(legend_data$Cluster)
  legend_data <- legend_data %>%
    dplyr::filter(CellType != "") %>%
    tidyr::drop_na() %>%
    dplyr::arrange(CellType, Cluster)
  if (nrow(legend_data) == 0) {
    stop("No non-empty cluster / cell type pairs to draw.", call. = FALSE)
  }

  # Per cell type: `times` = number of its clusters; `x` = circle column,
  # wrapping to a new row every col_num clusters. The old loop appended 1:0
  # when a count was an exact multiple of col_num, giving `x` the wrong length.
  cell_type_key <- as.character(legend_data$CellType)
  rank_in_type <- stats::ave(seq_along(cell_type_key), cell_type_key, FUN = seq_along)
  legend_data$times <- stats::ave(seq_along(cell_type_key), cell_type_key, FUN = length)
  legend_data$x <- (rank_in_type - 1L) %% col_num + 1L

  # Cell types with the most clusters first; arrange() is stable, so ties
  # keep their CellType/Cluster order
  legend_data <- legend_data %>% dplyr::arrange(dplyr::desc(times))
  # Every row of circles starts where x == 1. Number the rows top-down, then
  # flip so that the first row gets the largest y and is drawn at the top.
  row_start <- legend_data$x == 1L
  row_id <- cumsum(row_start)
  legend_data$y <- as.factor(sum(row_start) + 1L - row_id)

  # One y-axis label per row. Levels run bottom-up and rows top-down, hence
  # rev(). Rows a cell type wraps onto get a blank label instead of shifting
  # every later label.
  row_cell_type <- as.character(legend_data$CellType[row_start])
  row_labels <- ifelse(duplicated(row_cell_type), "", row_cell_type)

  if (is.null(all_cluster_num)) {
    all_cluster_num <- nlevels(legend_data$Cluster)
  }
  sorted_clusters <- sort(legend_data$Cluster)
  # Colors are picked by factor code (the position among the Cluster levels),
  # which is what indexing with the factor itself did
  cluster_colors <- get_colors(colors_len = all_cluster_num)[as.integer(sorted_clusters)]

  legend_fig <- ggplot2::ggplot(data = legend_data, ggplot2::aes(x = x, y = y)) +
    ggplot2::geom_point(ggplot2::aes(color = Cluster),
                        show.legend = FALSE,
                        size = 5) +
    ggplot2::scale_y_discrete(labels = rev(row_labels)) +
    ggplot2::scale_color_manual(breaks = sorted_clusters,
                                values = cluster_colors) +
    ggplot2::geom_text(ggplot2::aes(label = Cluster), size = 4) +
    ggplot2::theme(axis.line = ggplot2::element_blank(),
                   axis.text.x = ggplot2::element_blank(),
                   axis.title.x = ggplot2::element_blank(),
                   axis.title.y = ggplot2::element_blank(),
                   panel.background = ggplot2::element_blank(),
                   axis.text.y = ggplot2::element_text(size = 12),
                   axis.ticks.length = ggplot2::unit(0, "pt"))
  list(legend_data = legend_data, legend_fig = legend_fig)
}

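# Recommended composition: put the main plot and the legend plot side by side
# with a 5:2 width ratio (plot_layout() is from patchwork) and save the result
# at width 8 and height 5, e.g.:
# p <- tsne_fig + result$legend_fig + plot_layout(widths = c(5, 2))
# ggsave(p,filename = "tsne_beautif_3.png",width = 8,height = 5,dpi = 300)
# Example output: tsne_umap_beautif.png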

#------------------------------------------------分割线------------------------------------------------#
# Overlay a set of cells on an existing tSNE/UMAP plot. The cells are either
# those expressing a gene or an explicit barcode list. They are drawn as
# #504aa8-filled, black-outlined points with their own legend entry
# "<geneName>-expressing cells".
#
# project         Seurat object (only used when barcode_list is NULL)
# geneName        gene looked up with Seurat::FetchData(); also used in the
#                 legend text
# expression      cells whose value is strictly greater than this are shown
# barcode_list    optional character vector of cell barcodes to highlight
#                 instead of selecting cells by expression
# fig             the main plot (e.g. from draw_main()) to draw on
# reduction_data  table with a "cell" column and the embedding coordinates in
#                 columns 2 and 3 (as from get_reduction_data())
# dot_size        size of the highlighted points
# key_size        size of the point in the legend key
# Returns `fig` with the extra layer added.
plot_gene_dot <- function(project = NULL, geneName = NULL, expression = 0,
                          barcode_list = NULL, fig = NULL, reduction_data = NULL,
                          dot_size = 1, key_size = 5){
  library(dplyr)
  # The layer used to be added to an undefined/global `p`; use `fig`
  if (is.null(fig)) {
    stop("`fig` (the plot to draw on) must be supplied.", call. = FALSE)
  }
  if (!is.data.frame(reduction_data) || ncol(reduction_data) < 3) {
    stop("`reduction_data` must be a data frame with the embedding in columns 2-3.",
         call. = FALSE)
  }
  if (is.null(barcode_list)) {
    if (is.null(project) || is.null(geneName)) {
      stop("Supply either `barcode_list` or both `project` and `geneName`.",
           call. = FALSE)
    }
    expr_values <- Seurat::FetchData(object = project, vars = geneName)
    # which() drops NA comparisons, as the previous dplyr::filter() did
    barcode_list <- rownames(expr_values)[which(expr_values[[1]] > expression)]
  }
  x_col <- colnames(reduction_data)[2]
  y_col <- colnames(reduction_data)[3]
  data <- base::subset(reduction_data, cell %in% barcode_list)
  data$desc <- paste0(geneName,"-expressing cells")
  # Map fill to `desc` by name (not "last column") so an existing `desc`
  # column in reduction_data cannot shift the mapping
  fig + ggplot2::geom_point(data = data,
                            ggplot2::aes(x = .data[[x_col]],
                                         y = .data[[y_col]],
                                         fill = .data[["desc"]]),
                            shape = 21, size = dot_size, color = 'black') +
    ggplot2::scale_fill_manual(values = '#504aa8') +
    ggplot2::guides(fill = ggplot2::guide_legend(title = NULL,
                                                 override.aes = list(size = key_size)))
}

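# plot_gene_dot highlights, on a tSNE or UMAP plot, the cells that express a given gene.
# Arguments:
# project         Seurat object
# geneName        name of the gene to highlight (also used in the legend text)
# expression      expression threshold; only cells with a value strictly above it are highlighted
# barcode_list    alternatively, a vector of cell barcodes to highlight instead of
#                 selecting cells by geneName and expression
# fig             the main tSNE or UMAP plot to draw on
# reduction_data  the data returned by get_reduction_data
# dot_size        size of the highlighted points
# key_size        size of the point in the legend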

# data <- get_reduction_data(project=project,reduction='tsne')
# p <- draw_main(reduction_data=data,cluster_type="Cell_Type",show_labels=T,adjust_axis=F)
# p1 <- plot_gene_dot(project=project,geneName="IL33",expression=2,fig=p,reduction_data=data)
#------------------------------------------------分割线------------------------------------------------#

# Legend key glyph used by plot_number_circle(). It draws a solid circle
# (pch 16) in the key's colour, with the key's `label` value on top in black.
# The label is supplied through guide_legend(override.aes = list(label = ...)).
# Follows ggplot2's draw_key_*() signature (data, params, size); only `data`
# is used.
draw_number_circle <- function(data,params,size){
  # Local stand-in for `%||%`, which base R only provides from 4.4.0 on
  or_default <- function(value, default) if (is.null(value)) default else value
  key_alpha <- or_default(data$alpha, NA)
  circle <- grid::pointsGrob(
    x = 0.5, y = 0.5,
    size = grid::unit(1.5, "char"),
    pch = 16,
    gp = grid::gpar(col = ggplot2::alpha(or_default(data$colour, "grey50"), key_alpha),
                    fill = ggplot2::alpha(or_default(data$fill, "grey50"), key_alpha),
                    lwd = or_default(data$linewidth, 0.5) * ggplot2::.pt,
                    lty = or_default(data$linetype, 1))
  )
  # Draw the number once, centred (previously three identical copies were
  # stacked); with no label override the key is just the circle
  number <- grid::textGrob(
    label = or_default(data$label, ""),
    x = 0.5, y = 0.5,
    gp = grid::gpar(col = "black")
  )
  grid::grobTree(circle, number)
}

# Draw a tSNE/UMAP scatter plot colored by `colors_by`. Its legend keys are
# drawn with draw_number_circle(): colored circles that, when show_labels is
# TRUE, carry the `labels_by` values (e.g. color by cell type, number by
# cluster).
#
# reduction_data  data frame with the x/y embedding coordinates in columns 2-3
# colors_by       column used to color the points; blank values become NA
# labels_by       column written at the median position of its cells and into
#                 the legend circles (only used when show_labels = TRUE)
# show_legend     whether to show the color legend
# show_labels     whether to add the `labels_by` text and the legend numbers
# point_size      size of the points
# base_size       base font size for theme_classic()
# key_size        size of the legend keys
# label_size      text size of the on-plot labels
# adjust_axis     if TRUE, tidydr::theme_dr() keeps only short axis arrows in
#                 the lower-left corner
# arrow_type      arrow head type used when adjust_axis = TRUE ("closed"/"open")
#
# Legend entries follow the sorted `colors_by` values, while the numbers are
# taken in sorted `labels_by` order. Make `colors_by` a factor whose level
# order matches `labels_by` (one label per color). When the two counts differ
# the legend numbers are skipped with a warning.
# Returns a ggplot object.
plot_number_circle <- function(reduction_data  = NULL,
                      colors_by = "CellType",
                      labels_by = "Cluster",
                      show_legend = TRUE,
                      show_labels = FALSE,
                      point_size = 0.5,
                      base_size = 12,
                      key_size = 5,
                      label_size = 5,
                      adjust_axis = FALSE,
                      arrow_type = "closed"){
  library(dplyr)
  needed_cols <- c(colors_by, if (isTRUE(show_labels)) labels_by)
  if (!is.data.frame(reduction_data) || ncol(reduction_data) < 3 ||
      !all(needed_cols %in% colnames(reduction_data))) {
    stop("`reduction_data` must be a data frame with the embedding in ",
         "columns 2-3 and the column(s) ",
         paste0("'", needed_cols, "'", collapse = ", "), ".", call. = FALSE)
  }
  x_col <- colnames(reduction_data)[2]
  y_col <- colnames(reduction_data)[3]
  reduction_data[[colors_by]][reduction_data[[colors_by]] == ""] <- NA
  # Numeric ids would otherwise produce a continuous color scale
  if(is.numeric(reduction_data[[colors_by]])){
    reduction_data[[colors_by]] <- as.factor(reduction_data[[colors_by]])
  }
  present <- unique(reduction_data[[colors_by]])
  breaks <- sort(present[!is.na(present)])
  p <- ggplot2::ggplot() +
    ggplot2::geom_point(key_glyph = draw_number_circle,
                        data = reduction_data,
                        ggplot2::aes(x = .data[[x_col]],
                                     y = .data[[y_col]],
                                     color = .data[[colors_by]]),
                        size = point_size, show.legend = show_legend) +
    ggplot2::scale_color_manual(breaks = breaks,
                                values = get_colors(colors_len = length(breaks)),
                                guide = ggplot2::guide_legend(override.aes = list(size = key_size))) +
    ggplot2::theme_classic(base_size = base_size)
  if(isTRUE(adjust_axis)){
    p <- p + tidydr::theme_dr(xlength = 0.3, ylength = 0.3,
                              arrow = grid::arrow(length = ggplot2::unit(0.15, "inches"),
                                                  type = arrow_type)) +
      ggplot2::theme(panel.border = ggplot2::element_blank(),
                     panel.grid.major = ggplot2::element_blank(),
                     panel.grid.minor = ggplot2::element_blank())
  }
  if(isTRUE(show_labels)){
    label_pos <- reduction_data %>%
      dplyr::group_by(.data[[labels_by]]) %>%
      dplyr::summarise(x_pos = stats::median(.data[[x_col]]),
                       y_pos = stats::median(.data[[y_col]])) %>%
      dplyr::filter(.data[[labels_by]] != "") %>%
      tidyr::drop_na()
    p <- p + ggplot2::geom_text(data = label_pos,
                                ggplot2::aes(x = .data[["x_pos"]],
                                             y = .data[["y_pos"]],
                                             label = .data[[labels_by]]),
                                size = label_size)
    if (nrow(label_pos) == length(breaks)) {
      # guides() takes precedence over the scale's `guide`, so key_size must
      # be repeated here or the key size override would be lost
      p <- p + ggplot2::guides(
        color = ggplot2::guide_legend(
          override.aes = list(size = key_size,
                              label = as.character(label_pos[[labels_by]]))))
    } else {
      warning(sprintf(paste0("Legend numbers skipped: %d `%s` labels for %d ",
                             "`%s` legend entries."),
                      nrow(label_pos), labels_by, length(breaks), colors_by),
              call. = FALSE)
    }
  }
  p
}

# data <- data.table::fread("DRM_tsne_umap.txt") %>% filter(Cluster %in% c(1,2,3,4,5))
# data$CellType <- NA
# data$CellType[data$Cluster == "1"] <- "Th2"
# data$CellType[data$Cluster == "2"] <- "ILC1"
# data$CellType[data$Cluster == "3"] <- "DC2"
# data$CellType[data$Cluster == "4"] <- "Tc1"
# data$CellType[data$Cluster == "5"] <- "Treg"
# data$CellType <- factor(data$CellType,levels = unique(data$CellType)[order(unique(data$Cluster))])
# 
# p <- plot_number_circle(reduction_data = data,colors_by = "CellType",
#                labels_by = "Cluster",show_labels = T,
#                point_size = 1,adjust_axis = T,
#                label_size = 5)
# 
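# plot_number_circle draws the plot with colored circles as legend keys.
# guides(color = guide_legend(override.aes = list(label = label_pos[[1]]))) writes the numbers into those circles.
# Important: convert CellType to a factor and order its levels to match the Cluster order; otherwise the numbers in the legend circles will not correspond to the colors.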
#------------------------------------------------分割线------------------------------------------------#